import json
import os
import re
import time

from openai import OpenAI

# The OpenAI SDK is pointed at DeepSeek's endpoint via base_url.
# The key is read from the DEEPSEEK_API_KEY environment variable so it never has
# to be pasted into (and committed with) this file; when the variable is unset
# the client is still constructed with an empty key, as before.
client = OpenAI(
    base_url="https://api.deepseek.com",
    api_key=os.environ.get("DEEPSEEK_API_KEY", ""),
)

# Judge prompts. The three instructions are assembled from shared sentence
# fragments so the common wording lives in exactly one place.
_SCORE_FORMAT = (
    "After providing your explanation, output your final score by strictly "
    'following this format: "[[score]]", such as "[[7]]".'
)
_QUALITY_PREAMBLE = (
    "Please act as an impartial judge and evaluate the quality of the "
    "generation story contents provided by an AI assistant. "
    "Your job is to give a score out of 10."
)
_OBJECTIVITY = (
    "Do not allow the length of the responses to influence your evaluation. "
    "Be as objective as possible."
)

# Judges how consistent the writing style of the story is.
style_instruction = " ".join((
    _QUALITY_PREAMBLE,
    "Your evaluation should consider the style consistency of the story text.",
    _OBJECTIVITY,
    _SCORE_FORMAT,
))

# Judges how engaging the story is.
engage_instruction = " ".join((
    _QUALITY_PREAMBLE,
    "Your evaluation should consider the engaging level of the story.",
    _OBJECTIVITY,
    _SCORE_FORMAT,
))

# Judges how well the generated captions match the reference captions.
coherence_instruction = " ".join((
    "Please act as an impartial judge and evaluate how closely the generated "
    "captions align with the original captions in terms of meaning, style, "
    "and relevance to the story context.",
    _SCORE_FORMAT,
))

def api_call(messages, max_attempts=3, retry_delay=15):
    """Send a chat request and return the reply text, retrying on failure.

    Args:
        messages: Chat messages passed unchanged to
            ``client.chat.completions.create``.
        max_attempts: Maximum number of times the request is tried.
        retry_delay: Seconds to wait between two consecutive attempts.

    Returns:
        The stripped text of the first choice, or ``None`` when every attempt
        raised or the response did not contain usable text.
    """
    chat_completion = None
    for attempt in range(1, max_attempts + 1):
        try:
            chat_completion = client.chat.completions.create(
                messages=messages,
                model="deepseek-chat",  # Or the model you are using
            )
            break
        except Exception as e:
            # Best-effort boundary: any failure of the request is logged and retried.
            print(f"Error during API call: {e}")
            # Waiting only makes sense when another attempt will follow.
            if attempt < max_attempts:
                time.sleep(retry_delay)

    if chat_completion is None:
        return None

    try:
        return chat_completion.choices[0].message.content.strip()
    except (AttributeError, IndexError, TypeError, KeyError):
        # Covers an empty ``choices`` list, a missing attribute, or a ``None``
        # content, none of which surface as KeyError alone.
        print("Error: API response structure unexpected.")
        return None

def read_combined_output(filepath):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        filepath: Path of a file holding one JSON document per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(filepath, 'r', encoding='utf-8') as file:
        # Blank lines (e.g. a trailing newline at end of file) carry no record.
        return [json.loads(line) for line in file if line.strip()]

# A score as the model may write it: an integer or a decimal number.
_SCORE_NUMBER = r"(\d+(?:\.\d+)?)"

# Reply shapes that carry a score, tried in order of decreasing explicitness.
_SCORE_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE | re.MULTILINE)
    for pattern in (
        r"\[\[\s*" + _SCORE_NUMBER + r"\s*\]\]",                       # [[7]]
        r"score\s*\**\s*[:=]\s*\**\s*" + _SCORE_NUMBER,                # Score: 7, Score: 7/10
        r"(?<![\d.])" + _SCORE_NUMBER + r"\s*(?:/|out\s+of)\s*10\b",   # 7/10, 7 out of 10
        r"^\s*" + _SCORE_NUMBER + r"\s*$",                             # a line holding only "7"
    )
)


def _extract_score(text):
    """Return the first score found in *text* as an int, or None if there is none."""
    # Drop an echoed "(1-10)" range hint so its digits are not taken for the score.
    cleaned = re.sub(r"\(\s*1\s*-\s*10\s*\)", "", text)
    for pattern in _SCORE_PATTERNS:
        match = pattern.search(cleaned)
        if match:
            # Decimal scores are truncated so the result stays an int.
            return int(float(match.group(1)))
    return None


def evaluate_coherence_with_model(gen_caption, ref_caption):
    """Ask the model how similar a generated caption is to its reference.

    Args:
        gen_caption: Generated caption text.
        ref_caption: Reference caption text.

    Returns:
        The score parsed from the model's reply as an int, or ``0`` when the
        request failed or no score could be recognised in the reply.
    """
    # Create the prompt to ask the model to evaluate the similarity between the two captions
    prompt = f"Evaluate the similarity between the following two captions and give a score out of 10:\n\nGenerated caption: {gen_caption}\nReference caption: {ref_caption}\n\nScore (1-10):"

    # Plain-string content is the form accepted by every OpenAI-compatible
    # chat endpoint; a list of typed parts is not universally supported.
    messages = [{"role": "user", "content": prompt}]

    result = api_call(messages)

    # The prompt ends with "Score (1-10):", so the reply may be a bare number,
    # "Score: 7", "7/10", etc.; accept all of them rather than only "Score: N/10".
    if result:
        score = _extract_score(result)
        if score is not None:
            return score

    # If no score found, return 0 as a fallback
    return 0

def evaluate_models(assistant_data, instruction, reference_caption=None):
    """Evaluate one generated story under a single judging instruction.

    Args:
        assistant_data: Record whose ``'captions'`` entry is a list of caption
            strings; they are joined with spaces to form the story text.
        instruction: Judging instruction to apply.
        reference_caption: Optional list of reference caption strings. Only
            used when ``instruction`` is ``coherence_instruction``.

    Returns:
        For coherence with a non-empty reference: the int returned by
        ``evaluate_coherence_with_model``. Otherwise: the raw reply returned
        by ``api_call`` (text, or ``None`` on failure) - the score is not
        extracted from it here.
    """
    # Concatenate all captions into one string
    story = " ".join(assistant_data['captions'])

    # Coherence compares the story against the reference captions.
    if instruction == coherence_instruction and reference_caption:
        reference_caption_text = " ".join(reference_caption)
        return evaluate_coherence_with_model(story, reference_caption_text)

    # Style / engagement (and coherence without a reference): one API call
    # with the story followed by the instruction. Message content is sent as
    # a plain string, the form accepted by every OpenAI-compatible endpoint.
    messages = [
        {"role": "user", "content": "Story text: " + story},
        {"role": "user", "content": instruction},
    ]
    return api_call(messages)


def main():
    """Score every generated item on style, engagement and coherence.

    Reads the generated records and the reference records, evaluates each
    generated item under the three instructions, and writes one JSON line per
    item (``item_id`` and ``scores``) to ``./output_hae.jsonl``.

    Raises:
        ValueError: If there are fewer reference records than generated ones.
    """
    # Load the combined output data
    combined_data = read_combined_output('/root/Story-result/combined_outputEBM.jsonl')  # Replace with correct path
    val_data = read_combined_output('/root/Story-result/val_copy.jsonl')  # Replace with correct path

    # Items are paired with references by position; fail before spending any
    # API calls rather than with an IndexError halfway through the run.
    if len(val_data) < len(combined_data):
        raise ValueError(
            f"Reference file has {len(val_data)} records but "
            f"{len(combined_data)} generated items need one each."
        )

    # Metric name paired with the instruction used to judge it.
    evaluations = (
        ('style', style_instruction),
        ('engaging', engage_instruction),
        ('coherence', coherence_instruction),
    )

    # Each result is written as soon as it is available so that a failure
    # late in a long run does not discard the items already evaluated.
    with open('./output_hae.jsonl', 'w', encoding='utf-8') as output_file:
        for item_id, (item, reference) in enumerate(zip(combined_data, val_data), start=1):
            print(f"Evaluating item {item_id}...")
            reference_caption = reference['captions']  # Reference captions from the val dataset

            scores = {}
            for metric, instruction in evaluations:
                print(f"Evaluating with {metric} instruction...")

                # Only the coherence evaluation needs the reference captions.
                if instruction == coherence_instruction:
                    scores[metric] = evaluate_models(item, instruction, reference_caption)
                else:
                    scores[metric] = evaluate_models(item, instruction)

            entry = {'item_id': item_id, 'scores': scores}
            output_file.write(json.dumps(entry) + '\n')
            output_file.flush()

    print("Results saved to output_hae.jsonl")

# Run the evaluation only when executed as a script, so that importing this
# module (e.g. to reuse its helpers) does not start the whole run.
if __name__ == "__main__":
    main()
